import dictionary_learning.utils as utils
import sae
import torch
import torch.nn.functional as F
from training import train_meta_saes
from sae import BatchTopKSAE
from config import get_default_cfg, post_init_cfg
import json

device = 'cuda'

# Parallel lists handed to train_meta_saes: entry i of each list belongs to
# the same (base SAE, meta-SAE, meta-SAE config) triple.
meta_saes = []
cfgs = []
saes = []

# Set each of these to the folder that holds the corresponding artifacts.
CORE_PATH = '../SAEBench/eval_results/core'  # folder with SAEBench core results for the SAEs
BATCH_TOP_K_PATH = '/home/user/sae_models/batch_topk'  # folder with BatchTopK SAEs
ORT_PATH = '/home/user/sae_models/ort'  # folder with the "ort" SAEs
STANDARD_NEW_PATH = '/home/user/sae_models/standard_new'  # folder with the "standard_new" SAEs
matr_path = '/home/user/sae_models/matr'  # folder with the "matr" SAEs

# Placeholder: replace with the path to the JSON file holding the SAEBench
# core results of the SAE being processed.
CORE_METRICS_JSON = "path_to_SAEBench_core_metrics_of_model"

# Features whose reported density is not above this value are dropped before
# a meta-SAE is trained on the decoder rows.
MIN_FEATURE_DENSITY = 1e-7

batch_top_k_paths = utils.get_nested_folders(BATCH_TOP_K_PATH)
batch_top_k_ort_paths = utils.get_nested_folders(ORT_PATH)
standard_new_paths = utils.get_nested_folders(STANDARD_NEW_PATH)
matr_paths = utils.get_nested_folders(matr_path)


def _live_feature_indices(json_filename, min_density=MIN_FEATURE_DENSITY):
    """Return the indices of features with density above ``min_density``.

    Reads ``json_filename`` and collects ``feature['index']`` for every entry
    of ``eval_result_details`` whose ``feature_density`` exceeds the threshold.
    """
    with open(json_filename, 'r') as f:
        eval_data = json.load(f)
    return [
        feature['index']
        for feature in eval_data['eval_result_details']
        if feature['feature_density'] > min_density
    ]


def _build_meta_cfg(model_name, n_features, wandb_project, announce_dict_size=False):
    """Build the meta-SAE config shared by every experiment in this script.

    Only ``model_name``, ``wandb_project`` and the dictionary size (a quarter
    of ``n_features``) differ between experiments; every other setting is
    identical. When ``announce_dict_size`` is true the chosen dictionary size
    is printed. Returns the result of ``post_init_cfg``.
    """
    cfg = get_default_cfg()
    cfg['model_name'] = model_name
    cfg["sae_type"] = 'batch_topk'
    cfg["dict_size"] = n_features // 4
    if announce_dict_size:
        print(cfg["dict_size"])
    cfg["top_k"] = 4
    cfg["aux_penalty"] = 1/32
    cfg["lr"] = 1.2e-2
    cfg['batch_size'] = 2**12
    cfg["input_unit_norm"] = False
    cfg['wandb_project'] = wandb_project
    cfg['act_size'] = 2304
    cfg['device'] = device
    cfg['bandwidth'] = 0.001
    cfg["num_tokens"] = 8000*cfg['batch_size']
    cfg["model_dtype"] = torch.bfloat16
    cfg['perf_log_freq_base_metrics'] = 1
    return post_init_cfg(cfg)


def _register(base_sae, model_name, wandb_project, announce_dict_size=False):
    """Append ``base_sae`` plus a freshly built meta-SAE and its config.

    The meta-SAE dictionary size is derived from the number of rows of
    ``base_sae.W_dec``, so ``W_dec`` must already be in its final form.
    """
    saes.append(base_sae)
    cfg = _build_meta_cfg(
        model_name,
        base_sae.W_dec.shape[0],
        wandb_project,
        announce_dict_size=announce_dict_size,
    )
    meta_saes.append(BatchTopKSAE(cfg))
    cfgs.append(cfg)


def _w_dec_from_decoder(base_sae):
    """Expose the transposed decoder weight of ``base_sae`` as ``W_dec``."""
    base_sae.W_dec = base_sae.decoder.weight.data.T


def _w_dec_normalized_buffer(base_sae):
    """Replace ``W_dec`` with a buffer holding its L2-normalised rows."""
    w_dec_tensor = base_sae.W_dec.data.clone()
    # Delete the attribute and re-register it as a buffer before assigning a
    # plain tensor. Presumably W_dec is an nn.Parameter on this SAE class,
    # which would reject a plain-tensor assignment - TODO confirm.
    del base_sae.W_dec
    base_sae.register_buffer('W_dec', w_dec_tensor)
    base_sae.W_dec = F.normalize(base_sae.W_dec, p=2, dim=1)


def _add_group(sae_paths, wandb_project, prepare_w_dec, name_prefix='', echo_path=False):
    """Load every SAE under ``sae_paths`` and register a meta-SAE for it.

    For each path: load the dictionary, let ``prepare_w_dec`` set up
    ``W_dec``, keep only the rows listed by ``_live_feature_indices`` and
    register the result under the model name stored in
    ``config['trainer']['wandb_name']`` (optionally prefixed).
    """
    for sae_path in sae_paths:
        if echo_path:
            print(sae_path)
        base_sae, config = utils.load_dictionary(sae_path, device)
        prepare_w_dec(base_sae)

        valid_indices = _live_feature_indices(CORE_METRICS_JSON)
        base_sae.W_dec = base_sae.W_dec[valid_indices, :]

        model_name = config['trainer']['wandb_name']
        if name_prefix:
            model_name = name_prefix + model_name
        _register(base_sae, model_name, wandb_project)


# The order of the groups fixes the order of the entries in saes / meta_saes /
# cfgs, so keep it stable.
_add_group(batch_top_k_paths, 'meta_sae3.4', _w_dec_from_decoder)
_add_group(
    batch_top_k_ort_paths,
    'meta_sae10.1',
    _w_dec_from_decoder,
    name_prefix='add_exps_0_',
    echo_path=True,
)
_add_group(standard_new_paths, 'meta_sae3.4', _w_dec_from_decoder)
_add_group(matr_paths, 'meta_sae3.45', _w_dec_normalized_buffer)


# Random baseline: reuse the first "standard_new" SAE object as a container
# and overwrite its W_dec with 2**16 random unit-norm rows of width 2304.
baseline_sae, _ = utils.load_dictionary(standard_new_paths[0], device)
matrix = torch.randn(2**16, 2304)
matrix = matrix / matrix.norm(dim=1, keepdim=True)
baseline_sae.W_dec = matrix.to(device)
_register(baseline_sae, 'random_baseline', 'meta_sae3.4', announce_dict_size=True)


train_meta_saes(saes, meta_saes, cfgs)